{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transcribe 2 hours of audio in less than 2 minutes with Whisper\n"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/asr_tutorial.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to use the `pruna` package to optimize any custom whisper model. In this case, the smash function wraps the model into an efficient pipeline, which will transcribe 2 hours of audio in under 2 minutes on an A100 GPU We will use the `openai/whisper-large-v3` model as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
    "# the following command will install the latest version of pruna\n",
    "\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Loading the ASR model\n",
    "\n",
    "First, load your ASR model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForSpeechSeq2Seq\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
    "\n",
    "model_id = \"openai/whisper-large-v3\"\n",
    "\n",
    "model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch_dtype,\n",
    "    use_safetensors=True,\n",
    "    low_cpu_mem_usage=True,\n",
    ")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Initializing the Smash Config\n",
    "\n",
    "Next, initialize the smash_config. Since the compiler require a processor, we add it to the smash_config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig\n",
    "\n",
    "# Initialize the SmashConfig\n",
    "smash_config = SmashConfig([\"c_whisper\", \"whisper_s2t\"])\n",
    "smash_config.add_tokenizer(model_id)\n",
    "smash_config.add_processor(model_id)\n",
    "# uncomment the following line to quantize the model to 8 bits\n",
    "# smash_config.add({'c_whisper_weight_bits': 8})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Smashing the Model\n",
    "\n",
    "Now, smash the model. This will take approximately 2 minutes on a T4 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import smash\n",
    "\n",
    "# Smash the model\n",
    "smashed_model = smash(\n",
    "    model=model,\n",
    "    smash_config=smash_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Preparing the Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "response = requests.get(\n",
    "    \"https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/sam_altman_lex_podcast_367.flac\"\n",
    ")\n",
    "audio_sample = \"sam_altman_lex_podcast_367.flac\"\n",
    "\n",
    "# Save the content to the specified file\n",
    "with open(audio_sample, \"wb\") as f:\n",
    "    f.write(response.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Running the Model\n",
    "\n",
    "Finally, run the model to transcribe the audio file. Make sure you have `ffmpeg` installed."
   ]
  },
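  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because this step relies on `ffmpeg`, it can help to confirm that it is available before running the model. The cell below is a minimal sketch, not part of `pruna`: it assumes `ffmpeg` is expected on the `PATH` and looks it up with `shutil.which` from the Python standard library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# shutil.which returns the full path of the executable, or None if it is not on the PATH\n",
    "ffmpeg_path = shutil.which(\"ffmpeg\")\n",
    "\n",
    "if ffmpeg_path is None:\n",
    "    print(\"ffmpeg was not found on the PATH; install it before running the next cell.\")\n",
    "else:\n",
    "    print(f\"Found ffmpeg at {ffmpeg_path}\")"
   ]
  },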
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the result\n",
    "smashed_model(audio_sample)"
   ]
  },
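  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the throughput claim on your own hardware, you can time the call and compare the elapsed time with the audio duration. The cell below is a minimal sketch that only uses Python's standard `time` module. `timed_call` and `speedup_over_real_time` are hypothetical helper names introduced here for illustration; they are not part of `pruna`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def timed_call(fn, *args, **kwargs):\n",
    "    # Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)\n",
    "    start = time.perf_counter()\n",
    "    result = fn(*args, **kwargs)\n",
    "    return result, time.perf_counter() - start\n",
    "\n",
    "\n",
    "def speedup_over_real_time(audio_seconds, elapsed_seconds):\n",
    "    # Seconds of audio processed per second of wall-clock time\n",
    "    return audio_seconds / elapsed_seconds\n",
    "\n",
    "\n",
    "# 2 hours of audio transcribed in 2 minutes corresponds to a 60x speedup\n",
    "print(speedup_over_real_time(2 * 60 * 60, 2 * 60))\n",
    "\n",
    "# Example usage with the objects defined above (this runs the transcription again):\n",
    "# transcription, elapsed = timed_call(smashed_model, audio_sample)\n",
    "# print(f\"Transcribed in {elapsed:.1f} s\")"
   ]
  },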
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wrap Up\n",
    "\n",
    "Congratulations! You have successfully smashed an ASR model. You can now use the `pruna` package to optimize any custom ASR model. The only parts that you should modify are step 1, 4 and 5 to fit your use case."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pruna",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
